{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": "2.0.0\n"
    }
   ],
   "source": [
    "from __future__ import absolute_import, division, print_function, unicode_literals\n",
    "\n",
    "# TensorFlow and tf.keras\n",
    "import tensorflow as tf\n",
    "\n",
    "import os\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "from tensorflow import keras\n",
    "\n",
    "# Helper libraries\n",
    "import imageio\n",
    "import numpy as np\n",
    "import time\n",
    "import matplotlib.pyplot as plt\n",
    "from IPython import display\n",
    "from scipy import signal\n",
    "print(tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "feat_label = np.load('ahadb1.npy')\n",
    "trainECG   = feat_label[:,:-1]\n",
    "trainECG = signal.resample(trainECG,400,axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainECG = np.expand_dims(trainECG,2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": "(1125, 400, 1)"
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "trainECG.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "BUFFER_SIZE = 60000\n",
    "BATCH_SIZE = 10\n",
    "inputLength = 400\n",
    "\n",
    "train_dataset = tf.data.Dataset.from_tensor_slices(trainECG).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 构建模型\n",
    "# 生成器\n",
    "def build_generator():\n",
    "    generator = tf.keras.Sequential()\n",
    "    generator.add(layers.Dense(inputLength*BATCH_SIZE, input_shape=(inputLength,1),dtype='float32'))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 5, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 32,kernel_size = 5, padding = 'same'))\n",
    "    generator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    generator.add(layers.Conv1D(filters = 1,kernel_size = 5, activation = 'tanh', padding = 'same'))\n",
    "\n",
    "    return generator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 判别器\n",
    "def build_discriminator():\n",
    "\n",
    "    discriminator = tf.keras.Sequential()\n",
    "    discriminator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same',input_shape=(inputLength,1)))\n",
    "    discriminator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    discriminator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same')) \n",
    "    discriminator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    discriminator.add(layers.MaxPooling1D(pool_size=2))\n",
    "    discriminator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same')) \n",
    "    discriminator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    discriminator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same')) \n",
    "    discriminator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    discriminator.add(layers.MaxPooling1D(pool_size=2))\n",
    "    discriminator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same')) \n",
    "    discriminator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    discriminator.add(layers.Conv1D(filters = 32,kernel_size = 3, padding = 'same')) \n",
    "    discriminator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    discriminator.add(layers.MaxPooling1D(pool_size=2))\n",
    "    discriminator.add(layers.Flatten(input_shape=(inputLength,1)))\n",
    "    discriminator.add(layers.Dense(64,dtype='float32'))\n",
    "    discriminator.add(layers.Dropout(0.4))\n",
    "    discriminator.add(layers.LeakyReLU(alpha=0.2))\n",
    "    discriminator.add(layers.Dense(1, activation='tanh',dtype='float32'))\n",
    "\n",
    "    return discriminator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 实例化\n",
    "generator = build_generator()\n",
    "discriminator = build_discriminator()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "#测试运行  通过\n",
    "noise = tf.random.normal([20,400,1])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "generateECG = generator(noise)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "generateECG"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "discriminator(generateECG)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义优化器和损失函数\n",
    "cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 判别损失 判别器要做两件事情，既要真的趋近于1，又要假的趋近于0\n",
    "def discriminator_loss(real_output, fake_output):\n",
    "    real_loss = cross_entropy(-0.9*tf.ones_like(real_output), real_output)\n",
    "    fake_loss = cross_entropy(0.9*tf.ones_like(fake_output), fake_output)\n",
    "    total_loss = 0.4*real_loss+0.6*fake_loss\n",
    "    return total_loss\n",
    "\n",
    "# 生成损失  生成器使得假的趋近于1\n",
    "def generator_loss(fake_output):\n",
    "    return cross_entropy(-0.9*tf.ones_like(fake_output), fake_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义优化器\n",
    "generator_optimizer = tf.keras.optimizers.Adam(lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)\n",
    "discriminator_optimizer = tf.keras.optimizers.Adam(lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)\n",
    "total_optimizer = tf.keras.optimizers.Adam(lr=0.00001, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0, amsgrad=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义训练\n",
    "EPOCHS = 20\n",
    "noise_dim = inputLength\n",
    "num_examples_to_generate = 5\n",
    "\n",
    "\n",
    "# 我们将重复使用该种子（因此在动画 GIF 中更容易可视化进度）\n",
    "seed = tf.random.normal([num_examples_to_generate, noise_dim, 1],dtype='float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "seed = tf.cos(4*seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 单步训练\n",
    "# 注意 `tf.function` 的使用\n",
    "# 该注解使函数被“编译”\n",
    "@tf.function\n",
    "def train_step(ECG):\n",
    "    noise = tf.random.normal([BATCH_SIZE, noise_dim,1])\n",
    "    noise = tf.cos(4*noise)\n",
    "    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape,tf.GradientTape() as total_tape:\n",
    "      generated_ECG = generator(noise, training=True)\n",
    "\n",
    "      real_output = discriminator(ECG, training=True)\n",
    "      fake_output = discriminator(generated_ECG, training=True)\n",
    "\n",
    "      gen_loss = generator_loss(fake_output)\n",
    "      disc_loss = discriminator_loss(real_output, fake_output)\n",
    "      total_loss = tf.tanh(tf.abs(gen_loss)-tf.abs(disc_loss))\n",
    "    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)\n",
    "    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)\n",
    "    gradients_of_total = total_tape.gradient(total_loss, generator.trainable_variables)\n",
    "\n",
    "    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))\n",
    "    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))\n",
    "    total_optimizer.apply_gradients(zip(gradients_of_total,generator.trainable_variables))\n",
    "    return gen_loss,disc_loss,total_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 定义训练\n",
    "def train(dataset, epochs):\n",
    "    for epoch in range(epochs):\n",
    "        start = time.time()\n",
    "        cunt = 0\n",
    "        for image_batch in dataset:\n",
    "            cunt+=1\n",
    "            gen_loss,disc_loss,total_loss = train_step(image_batch)\n",
    "            if cunt%100==0:\n",
    "                # 继续进行时为 GIF 生成图像\n",
    "                display.clear_output(wait=True)\n",
    "                generate_and_save_images(generator,\n",
    "                                            epoch + \n",
    "                                            1,\n",
    "                                            seed)\n",
    "                print ('gen loss {} __ dis loss {} __tol loss{}'.format(gen_loss, disc_loss, total_loss))\n",
    "        print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))\n",
    "\n",
    "    # 最后一个 epoch 结束后生成图片\n",
    "    display.clear_output(wait=True)\n",
    "    generate_and_save_images(generator,\n",
    "                            epochs,\n",
    "                            seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_and_save_images(model, epoch, test_input):\n",
    "    predictions = model(test_input, training=False)\n",
    "    fig, axs = plt.subplots(5, 1)\n",
    "    for i in range(5):\n",
    "        plot_sigs = predictions[i].numpy()\n",
    "        axs[i].plot(plot_sigs)\n",
    "        axs[i].axis('off')\n",
    "   \n",
    "    plt.savefig('images\\image_at_epoch_{:04d}.png'.format(epoch))\n",
    "    plt.show()\n",
    "#     plt.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "train(train_dataset, 500)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "generator.save('generator_model3.h5')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "noise = tf.random.normal([BATCH_SIZE, noise_dim,1])\n",
    "noise = tf.cos(4*noise)\n",
    "epoch  = epoch+1\n",
    "generate_and_save_images(generator, epoch, noise)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用 epoch 数生成单张图片\n",
    "def display_image(epoch_no):\n",
    "  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {},
   "outputs": [],
   "source": [
    "anim_file = 'dcgan.gif'\n",
    "\n",
    "with imageio.get_writer(anim_file, mode='I') as writer:\n",
    "  filenames = glob.glob('image*.png')\n",
    "  filenames = sorted(filenames)\n",
    "  last = -1\n",
    "  for i,filename in enumerate(filenames):\n",
    "    frame = 2*(i**0.5)\n",
    "    if round(frame) > round(last):\n",
    "      last = frame\n",
    "    else:\n",
    "      continue\n",
    "    image = imageio.imread(filename)\n",
    "    writer.append_data(image)\n",
    "  image = imageio.imread(filename)\n",
    "  writer.append_data(image)\n",
    "\n",
    "import IPython\n",
    "if IPython.version_info > (6,2,0,''):\n",
    "  display.Image(filename=anim_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}